#!/usr/bin/env python
# coding: utf-8

import numbers

import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F



class GraphConvolution(nn.Module):
    """One graph-convolution layer computing ``ELU(adj @ linear(x))``.

    Args:
        n_in: Input feature size handed to ``nn.Linear``.
        n_out: Output feature size handed to ``nn.Linear``.
        bias: Whether the linear transform learns an additive bias.
    """

    def __init__(self, n_in, n_out, bias=True):
        super(GraphConvolution, self).__init__()
        self.n_in = n_in
        self.n_out = n_out
        # Forward ``bias`` so that ``bias=False`` really produces a bias-free
        # layer instead of being silently ignored.
        self.linear = nn.Linear(n_in, n_out, bias=bias)

    def forward(self, x, adj):
        """Apply the linear transform, propagate through ``adj``, then ELU.

        Args:
            x: Node features; the last dimension must equal ``n_in``.
            adj: Left operand of ``torch.spmm`` (presumably a sparse
                adjacency matrix -- confirm against the caller).

        Returns:
            ``F.elu(torch.spmm(adj, self.linear(x)))``.
        """
        out = self.linear(x)
        return F.elu(torch.spmm(adj, out))


class GCN(nn.Module):
    """Stack of ``GraphConvolution`` layers, each fed its own adjacency.

    The first layer maps ``nfeat`` to ``nhid``; the remaining ``layers - 1``
    layers map ``nhid`` to ``nhid``. At least one layer is always created.
    Dropout is applied to the output of every layer.
    """

    def __init__(self, nfeat, nhid, layers, dropout):
        super().__init__()
        self.layers = layers
        self.nhid = nhid
        # Input width of every layer: raw features first, hidden size after.
        input_sizes = [nfeat] + [nhid] * (layers - 1)
        self.gcs = nn.ModuleList(
            [GraphConvolution(size, nhid) for size in input_sizes]
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, adjs):
        """Run ``x`` through every layer, using ``adjs[i]`` for layer ``i``.

        Args:
            x: Input node features.
            adjs: Indexable collection holding one adjacency per layer.

        Returns:
            The output of the last layer after dropout.
        """
        for depth, conv in enumerate(self.gcs):
            hidden = conv(x, adjs[depth])
            x = self.dropout(hidden)
        return x

class SuGCN(nn.Module):
    """Classifier head: encoder -> dropout -> linear layer.

    Args:
        encoder: Callable as ``encoder(feat, adjs)`` and exposing ``nhid``,
            which sets the input width of the final linear layer.
        num_classes: Output width of the final linear layer.
        dropout: Dropout probability applied to the encoder output.
        inp: Not used by this class; kept so existing call sites still work.
    """

    def __init__(self, encoder, num_classes, dropout, inp):
        super().__init__()
        self.encoder = encoder
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(self.encoder.nhid, num_classes)

    def forward(self, feat, adjs):
        """Return the linear layer's output for the encoded features."""
        encoded = self.encoder(feat, adjs)
        return self.linear(self.dropout(encoded))


# for efficient full batch inference
class GCN_full(nn.Module):
    """Stack of ``GraphConvolution`` layers sharing one adjacency matrix.

    The layer layout is the same as in ``GCN`` (``nfeat -> nhid`` followed by
    ``layers - 1`` layers of ``nhid -> nhid``); only ``forward`` differs.
    """

    def __init__(self, nfeat, nhid, layers, dropout):
        super().__init__()
        self.layers = layers
        self.nhid = nhid
        # Input width of every layer: raw features first, hidden size after.
        input_sizes = [nfeat] + [nhid] * (layers - 1)
        self.gcs = nn.ModuleList(
            [GraphConvolution(size, nhid) for size in input_sizes]
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, adj):
        '''
            Unlike GCN.forward, which indexes a separate adjacency matrix for
            every layer, the single ``adj`` received here is reused by all
            layers.
        '''
        for conv in self.gcs:
            hidden = conv(x, adj)
            x = self.dropout(hidden)
        return x

class SuGCN_full(nn.Module):
    """Classifier head for full-batch use: encoder -> dropout -> linear.

    Args:
        encoder: Callable as ``encoder(feat, adj)`` and exposing ``nhid``,
            which sets the input width of the final linear layer.
        num_classes: Output width of the final linear layer.
        dropout: Dropout probability applied to the encoder output.
        inp: Not used by this class; kept so existing call sites still work.
    """

    def __init__(self, encoder, num_classes, dropout, inp):
        super().__init__()
        self.encoder = encoder
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(self.encoder.nhid, num_classes)

    def forward(self, feat, adj):
        """Return the linear layer's output; ``adj`` is a single adjacency."""
        encoded = self.encoder(feat, adj)
        return self.linear(self.dropout(encoded))

# For sketching and recording HW norm.
class SKGCN(nn.Module):
    """GCN encoder that can record the row-wise L2 norm of ``H W`` per layer.

    Args:
        nfeat: Input feature size of the first linear layer.
        nhid: Hidden size; output width of every linear layer.
        layers: Number of layers; also the first dimension of
            ``HW_row_norm``.
        dropout: Dropout probability applied after every layer.
        n_nodes: Number of nodes whose norms can be recorded; the second
            dimension of ``HW_row_norm``.
    """

    def __init__(self, nfeat, nhid, layers, dropout, n_nodes):
        super(SKGCN, self).__init__()
        self.layers = layers
        self.nhid = nhid
        # Only the linear maps are stored here; the multiplication with the
        # adjacency matrix happens directly in forward().
        self.lins = nn.ModuleList()
        self.lins.append(nn.Linear(nfeat, nhid))
        self.dropout = nn.Dropout(dropout)
        for _ in range(layers - 1):
            self.lins.append(nn.Linear(nhid, nhid))

        # Plain tensor attribute (neither a parameter nor a registered
        # buffer); forward() writes CPU copies of the recorded norms into it.
        self.HW_row_norm = torch.zeros([layers, n_nodes], requires_grad=False)

    def forward(self, x, adjs, after_nodes_ls=None, record_HW=False):
        '''
            Run every layer as ``dropout(ELU(adjs[i] @ lins[i](x)))``.

            When ``record_HW`` is true, the L2 norm of each row of
            ``lins[i](x)`` is stored in
            ``HW_row_norm[i, after_nodes_ls[i]]`` before the adjacency is
            applied, so ``after_nodes_ls`` must then hold one index object
            per layer. It is never read when ``record_HW`` is false.
        '''
        # ``None`` replaces the former mutable ``[]`` default. An empty tuple
        # keeps the IndexError raised when recording without any indices.
        if after_nodes_ls is None:
            after_nodes_ls = ()

        for idx, lin in enumerate(self.lins):
            x = lin(x)
            if record_HW:
                # Detach before copying so the stored norms carry no
                # autograd history.
                row_norm = torch.norm(x, p=2, dim=1).detach().cpu()
                self.HW_row_norm[idx, after_nodes_ls[idx]] = row_norm

            x = self.dropout(F.elu(torch.spmm(adjs[idx], x)))
        return x


# Vertex Covering
class SuSKGCN(nn.Module):
    """Classifier head that forwards the recording options to its encoder.

    Args:
        encoder: Callable as
            ``encoder(feat, adjs, after_nodes_ls, record_HW)`` and exposing
            ``nhid``, which sets the input width of the final linear layer.
        num_classes: Output width of the final linear layer.
        dropout: Dropout probability applied to the encoder output.
        inp: Not used by this class; kept so existing call sites still work.
    """

    def __init__(self, encoder, num_classes, dropout, inp):
        super(SuSKGCN, self).__init__()
        self.encoder = encoder
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(self.encoder.nhid, num_classes)

    def forward(self, feat, adjs, after_nodes_ls=None, record_HW=False):
        """Encode, apply dropout, and project to ``num_classes`` outputs.

        Args:
            feat: Input features passed to the encoder.
            adjs: Adjacency argument passed to the encoder unchanged.
            after_nodes_ls: Per-layer index objects passed to the encoder;
                ``None`` is replaced by a fresh empty list.
            record_HW: Recording flag passed to the encoder unchanged.
        """
        # ``None`` replaces the former mutable ``[]`` default; the encoder
        # still receives an empty list when the caller supplies nothing.
        if after_nodes_ls is None:
            after_nodes_ls = []
        x = self.encoder(feat, adjs, after_nodes_ls, record_HW)
        x = self.dropout(x)
        x = self.linear(x)
        return x

class GraphConvolutionVc(nn.Module):
    """Graph convolution with an optional extra term for previous nodes.

    Without ``D_inv`` the layer computes ``ELU(adj @ linear(x))``. With
    ``D_inv`` it computes
    ``ELU(adj @ linear(x)[nabr rows] + D_inv @ linear(x)[:loc[0]])``.

    Args:
        n_in: Input feature size handed to ``nn.Linear``.
        n_out: Output feature size handed to ``nn.Linear``.
        bias: Whether the linear transform learns an additive bias.
    """

    def __init__(self, n_in, n_out, bias=True):
        super(GraphConvolutionVc, self).__init__()
        self.n_in = n_in
        self.n_out = n_out
        # Forward ``bias`` so that ``bias=False`` is honoured rather than
        # silently ignored.
        self.linear = nn.Linear(n_in, n_out, bias=bias)

    def forward(self, x, adj, D_inv=None, loc=None):
        """Apply the layer, optionally splitting rows by ``loc``.

        Args:
            x: Node features; the last dimension must equal ``n_in``.
            adj: Left operand of ``torch.spmm`` for the neighbour term.
            D_inv: Optional left operand of ``torch.spmm`` for the
                previous-node term. ``None`` selects the regular forward.
            loc: Pair read only when ``D_inv`` is given. ``loc[0]`` is the
                end of the previous-node rows (``[:loc[0]]``). ``loc[1]`` is
                the start of the neighbour rows (``[loc[1]:]``) when it is
                an integer; any other index object is applied to the rows
                as given.

        Returns:
            The ELU-activated layer output.
        """
        out = self.linear(x)

        # When D_inv is None, do the regular forward.
        if D_inv is None:
            return F.elu(torch.spmm(adj, out))

        nabr_sel = loc[1]
        if isinstance(nabr_sel, numbers.Integral):
            # An integer marks where the neighbour rows start. Indexing with
            # it directly would pick a single row (a 1-D tensor), which
            # torch.spmm cannot multiply, so take the slice instead.
            nabr_rows = out[nabr_sel:, :]
        else:
            # Index tensors, lists and slices keep their previous meaning.
            nabr_rows = out[nabr_sel, :]

        prev_rows = out[:loc[0], :]
        return F.elu(torch.spmm(adj, nabr_rows) + torch.spmm(D_inv, prev_rows))
        
        
class VcGCN(nn.Module):
    """Stack of ``GraphConvolutionVc`` layers with per-layer inputs.

    The first layer maps ``nfeat`` to ``nhid``; the remaining ``layers - 1``
    layers map ``nhid`` to ``nhid``. At least one layer is always created.
    Dropout is applied to the output of every layer.
    """

    def __init__(self, nfeat, nhid, layers, dropout):
        super().__init__()
        self.layers = layers
        self.nhid = nhid
        # Built from GraphConvolutionVc rather than GraphConvolution.
        input_sizes = [nfeat] + [nhid] * (layers - 1)
        self.gcs = nn.ModuleList(
            [GraphConvolutionVc(size, nhid) for size in input_sizes]
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, adjs, D_inv_ls=None, loc_ls=None):
        """Run every layer, using the ``i``-th entry of each list for layer ``i``.

        Args:
            x: Input node features.
            adjs: Indexable collection holding one adjacency per layer.
            D_inv_ls: Optional per-layer ``D_inv`` arguments. ``None`` runs
                every layer without ``D_inv`` and ``loc``.
            loc_ls: Per-layer ``loc`` arguments; read only when ``D_inv_ls``
                is given.

        Returns:
            The output of the last layer after dropout.
        """
        full_batch = D_inv_ls is None
        for depth, conv in enumerate(self.gcs):
            if full_batch:
                # full-batch inference
                hidden = conv(x, adjs[depth])
            else:
                # non-zero sampling
                hidden = conv(x, adjs[depth], D_inv_ls[depth], loc_ls[depth])
            x = self.dropout(hidden)
        return x

class SuVcGCN(nn.Module):
    """Classifier head: encoder -> dropout -> linear layer.

    Args:
        encoder: Callable as ``encoder(feat, adjs, D_inv_ls, loc_ls)`` and
            exposing ``nhid``, which sets the input width of the final
            linear layer.
        num_classes: Output width of the final linear layer.
        dropout: Dropout probability applied to the encoder output.
        inp: Not used by this class; kept so existing call sites still work.
    """

    def __init__(self, encoder, num_classes, dropout, inp):
        super().__init__()
        self.encoder = encoder
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(self.encoder.nhid, num_classes)

    def forward(self, feat, adjs, D_inv_ls=None, loc_ls=None):
        """Return the linear layer's output; all arguments go to the encoder."""
        encoded = self.encoder(feat, adjs, D_inv_ls, loc_ls)
        return self.linear(self.dropout(encoded))


